(本貼文所列出的程式碼,皆以 colab 筆記本方式執行,[可由此下載]
(https://colab.research.google.com/drive/1tIu9KwFqp7dZ_vCLOZiQ0NK_Y0av0vGF?usp=sharing)
scan 算是最為複雜的 JAX 控制流程運算子,它其實和 TensorFlow 的 tf.scan() 很像,算是其簡化的版本。如果老頭的說明不夠詳盡,使得你不能完全領會 JAX scan 的用法的話,你不妨先去翻翻 tf.scan ,這會對你很有幫助。
scan 的基本操作,是用掃描函式 f,依序以陣列元素為輸入值,產生掃描結果,最終將所有的掃描結果組成陣列回傳。
在掃描的過程中,f 也需要「傳承 carry」做為其輸入參數,並輸出下一個傳承,供掃描下一個陣列示素時使用。傳承的目的,是提供掃描程式執行時的「環境脈絡 context」,例如目前處理的陣列元素它的索引值等。
f :
init :
xs:
length:
reverse:
unroll:
回傳值:
scan 的運作流程,可以用以下的 Python 程式段來說明:
def scan(f, init, xs, length=None):
if xs is None:
xs = [None] * length
carry = init
ys = []
for x in xs:
carry, y = f(carry, x)
ys.append(y)
return carry, np.stack(ys)
計算陣列的累積和 (cumulative sum)
計算 1 加到 10 的和
這個例子說明的輸入陣列 xs 為 none,藉由指定 length 的值來達到目的。
可以看出來,當設定較大的 unroll 值後,scan 的執行速度會略為加快。